# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""Tests for tensorflow.python.framework.op_def_library."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from google.protobuf import text_format

from tensorflow.core.framework import op_def_pb2
from tensorflow.core.framework import tensor_shape_pb2
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import test_util
from tensorflow.python.framework.op_def_library import OpDefLibrary
from tensorflow.python.platform import googletest


# NOTE(mrry): Dummy shape registrations for ops used in the tests.
# Every op type below gets `ops.RegisterShape(<op type>)(None)`, in the order
# listed.
for _op_type in (
    "Attr",
    "AttrBool",
    "AttrBoolList",
    "AttrDefault",
    "AttrEmptyListDefault",
    "AttrEnum",
    "AttrEnumList",
    "AttrFloat",
    "AttrListDefault",
    "AttrListMin",
    "AttrMin",
    "AttrShape",
    "AttrShapeList",
    "AttrPartialShape",
    "AttrPartialShapeList",
    "Binary",
    "ComplexStruct",
    "InPolymorphicTwice",
    "MixedStruct",
    "NInPolymorphicTwice",
    "NInTwice",
    "NInTwoTypeVariables",
    "NIntsIn",
    "NIntsOut",
    "NIntsOutDefault",
    "NPolymorphicIn",
    "NPolymorphicOut",
    "NPolymorphicOutDefault",
    "NPolymorphicRestrictIn",
    "NPolymorphicRestrictOut",
    "OutT",
    "OutTypeList",
    "OutTypeListRestrict",
    "Polymorphic",
    "PolymorphicDefaultOut",
    "PolymorphicOut",
    "RefIn",
    "RefOut",
    "ReservedAttr",
    "ReservedInput",
    "Restrict",
    "Simple",
    "SimpleStruct",
    "TwoRefsIn",
    "TypeList",
    "TypeListRestrict",
    "TypeListTwice",
):
  ops.RegisterShape(_op_type)(None)
# Do not leave the loop variable behind as a module attribute.
del _op_type


class OpDefLibraryTest(test_util.TensorFlowTestCase):

  def setUp(self):
    """Creates a fresh OpDefLibrary and graph and adds the shared ops.

    The graph's `as_default()` context is entered manually here and left again
    in `tearDown`. The 'Simple' and 'OutT' ops are added to the library because
    the tests and the `Tensor` helper apply them.
    """
    self._lib = OpDefLibrary()
    self._g = ops.Graph()
    self._default_graph_controller = self._g.as_default()
    self._default_graph_controller.__enter__()
    try:
      self._add_op("name: 'Simple' input_arg { name: 'a' type: DT_INT32 } "
                   "output_arg { name: 'out' type: DT_FLOAT }")
      self._add_op("name: 'OutT' output_arg { name: 'a' type_attr: 'T' } "
                   "attr { name: 'T' type: 'type' }")
    except BaseException:
      # unittest does not call tearDown when setUp raises, so leave the graph
      # context here instead of leaking it into whatever runs next.
      self._default_graph_controller.__exit__(None, None, None)
      raise

  def tearDown(self):
    """Leaves the default-graph context that `setUp` entered."""
    controller = self._default_graph_controller
    controller.__exit__(None, None, None)

  def _add_op(self, ascii):
    """Parses a text-format OpDef and adds it to the library under test.

    Args:
      ascii: Text-format protocol buffer string describing a single OpDef.
    """
    parsed_op_def = op_def_pb2.OpDef()
    text_format.Merge(ascii, parsed_op_def)
    self._lib.add_op(parsed_op_def)

  def Tensor(self, t, name="in"):
    """Applies the 'OutT' op and returns whatever `apply_op` produces.

    Args:
      t: Value passed as the 'T' attr of the 'OutT' op.
      name: Name passed to `apply_op` for the created op.

    Returns:
      The result of applying 'OutT' with the given attr and name.
    """
    result = self._lib.apply_op("OutT", name=name, T=t)
    return result

  def testNoRegisteredOpFails(self):
    """Applying an op name that was never added raises RuntimeError."""
    with self.assertRaises(RuntimeError) as raised:
      self._lib.apply_op("unknown")
    message = str(raised.exception)
    self.assertEqual(message, "Unrecognized Op name unknown")

  def testAddOpValidation(self):
    """Adding a malformed OpDef raises TypeError with a specific message."""
    # Each entry pairs a text-format OpDef with the exact error text expected
    # when it is added. Entries are tried in the order listed.
    invalid_op_defs = [
        ("name: 'MissingTypeAttr' "
         "input_arg { name: 'a' type_attr: 'T' } ",
         "Inconsistent OpDef for 'MissingTypeAttr', "
         "missing attr 'T'"),
        ("name: 'BadTypeAttr' "
         "output_arg { name: 'a' type_attr: 'T' } "
         "attr { name: 'T' type: 'int' }",
         "Attr 'T' of 'BadTypeAttr' used as a type_attr but has type int"),
        ("name: 'MissingNumberAttr' "
         "input_arg { name: 'a' type: DT_INT32 number_attr: 'N' } ",
         "Inconsistent OpDef for 'MissingNumberAttr', "
         "missing attr 'N'"),
        ("name: 'BadNumberAttr' "
         "output_arg { name: 'a' type: DT_INT32 number_attr: 'N' } "
         "attr { name: 'N' type: 'type' }",
         "Attr 'N' of 'BadNumberAttr' used as a number_attr but has type type"),
        ("name: 'TwoTypesA' "
         "input_arg { name: 'a' type: DT_INT32 type_attr: 'T' } "
         "attr { name: 'T' type: 'type' }",
         "Arg 'a' of 'TwoTypesA' must have one type field not 2"),
        ("name: 'TwoTypesB' "
         "input_arg { name: 'a' type: DT_INT32 type_list_attr: 'T' } "
         "attr { name: 'T' type: 'list(type)' }",
         "Arg 'a' of 'TwoTypesB' must have one type field not 2"),
        ("name: 'ThreeTypes' "
         "input_arg { name: 'a' type: DT_INT32 type_attr: 'T' "
         "type_list_attr: 'U' } "
         "attr { name: 'T' type: 'type' } "
         "attr { name: 'U' type: 'list(type)' }",
         "Arg 'a' of 'ThreeTypes' must have one type field not 3"),
        ("name: 'NoTypes' output_arg { name: 'a' } ",
         "Arg 'a' of 'NoTypes' must have one type field not 0"),
    ]

    for op_def_text, expected_message in invalid_op_defs:
      with self.assertRaises(TypeError) as raised:
        self._add_op(op_def_text)
      self.assertEqual(str(raised.exception), expected_message)

  def testSimple(self):
    """Applies 'Simple' to int inputs and checks dtype and NodeDefs."""
    apply_op = self._lib.apply_op

    # No name given: the node is expected to be named after the op type.
    unnamed = apply_op("Simple", a=3)
    self.assertEqual(dtypes.float32, unnamed.dtype)
    expected = """
      name: 'Simple' op: 'Simple' input: 'Simple/a'
      """
    self.assertProtoEquals(expected, unnamed.op.node_def)

    # A second unnamed application is expected to get the name 'Simple_1'.
    unnamed_again = apply_op("Simple", a=4)
    expected = """
      name: 'Simple_1' op: 'Simple' input: 'Simple_1/a'
      """
    self.assertProtoEquals(expected, unnamed_again.op.node_def)

    explicitly_named = apply_op("Simple", a=5, name="named")
    expected = """
      name: 'named' op: 'Simple' input: 'named/a'
      """
    self.assertProtoEquals(expected, explicitly_named.op.node_def)

    # A nested list is accepted as the input value.
    nested_input = apply_op("Simple", a=[[1, 2, 3], [4, 5, 6]], name="two_d")
    expected = """
      name: 'two_d' op: 'Simple' input: 'two_d/a'
      """
    self.assertProtoEquals(expected, nested_input.op.node_def)

  def testSimpleFailures(self):
    """Bad, mistyped, extra, or missing arguments to 'Simple' raise TypeError."""
    apply_op = self._lib.apply_op

    # A Python string where an int32 input is expected.
    with self.assertRaises(TypeError) as raised:
      apply_op("Simple", a="Bad string")
    message = str(raised.exception)
    self.assertEqual(message,
                     "Expected int32, got 'Bad string' of type 'str' instead.")

    # A string-typed output of another op where int32 is expected.
    with self.assertRaises(TypeError) as raised:
      apply_op("Simple", a=self.Tensor(dtypes.string))
    message = str(raised.exception)
    self.assertEqual(message,
                     "Input 'a' of 'Simple' Op has type string "
                     "that does not match expected type of int32.")

    # One unknown keyword argument.
    with self.assertRaises(TypeError) as raised:
      apply_op("Simple", a=6, extra="bogus")
    message = str(raised.exception)
    self.assertEqual(message,
                     "apply_op() got unexpected keyword arguments: extra")

    # Several unknown keyword arguments are all reported.
    with self.assertRaises(TypeError) as raised:
      apply_op("Simple", a=6, extra1="bogus", extra2="also_bogus")
    message = str(raised.exception)
    self.assertEqual(message,
                     "apply_op() got unexpected keyword arguments: extra1, "
                     "extra2")

    # The required input is missing entirely.
    with self.assertRaises(TypeError) as raised:
      apply_op("Simple")
    message = str(raised.exception)
    self.assertEqual(message, "No argument for input a")

    # The required input is missing and an unrelated keyword is supplied.
    with self.assertRaises(TypeError) as raised:
      apply_op("Simple", wrong=7)
    message = str(raised.exception)
    self.assertEqual(message, "No argument for input a")

  def testReservedInput(self):
    """An input arg named 'input' is supplied through the keyword `input_`."""
    self._add_op("name: 'ReservedInput' "
                 "input_arg { name: 'input' type: DT_INT32 } ")
    created_op = self._lib.apply_op("ReservedInput", input_=7, name="x")
    expected = """
      name: 'x' op: 'ReservedInput' input: 'x/input'
      """
    self.assertProtoEquals(expected, created_op.node_def)

  def testPolymorphic(self):
    """Attr 'T' of 'Polymorphic' is taken from the input and cannot be passed."""
    self._add_op("name: 'Polymorphic' "
                 "input_arg { name: 'a' type_attr: 'T' } "
                 "output_arg { name: 'out' type_attr: 'T' } "
                 "attr { name: 'T' type: 'type' }")
    apply_op = self._lib.apply_op

    from_int = apply_op("Polymorphic", a=7, name="p")
    self.assertEqual(dtypes.int32, from_int.dtype)
    expected = """
      name: 'p' op: 'Polymorphic' input: 'p/a'
      attr { key: 'T' value { type: DT_INT32 } }
      """
    self.assertProtoEquals(expected, from_int.op.node_def)

    from_string = apply_op("Polymorphic", a="s", name="q")
    self.assertEqual(dtypes.string, from_string.dtype)
    expected = """
      name: 'q' op: 'Polymorphic' input: 'q/a'
      attr { key: 'T' value { type: DT_STRING } }
      """
    self.assertProtoEquals(expected, from_string.op.node_def)

    from_string_list = apply_op("Polymorphic", a=["s", "t", "u"], name="r")
    self.assertEqual(dtypes.string, from_string_list.dtype)
    expected = """
      name: 'r' op: 'Polymorphic' input: 'r/a'
      attr { key: 'T' value { type: DT_STRING } }
      """
    self.assertProtoEquals(expected, from_string_list.op.node_def)

    # Explicitly passing the inferred attr is an error.
    with self.assertRaises(TypeError) as raised:
      apply_op("Polymorphic", a="s", T=dtypes.string)
    message = str(raised.exception)
    self.assertEqual(message,
                     "Should not specify value for inferred attr 'T'.")

  def testPolymorphicOut(self):
    """Attr 'T' of 'PolymorphicOut' must be given and sets the output dtype."""
    self._add_op("name: 'PolymorphicOut' "
                 "output_arg { name: 'out' type_attr: 'T' } "
                 "attr { name: 'T' type: 'type' }")
    apply_op = self._lib.apply_op

    int_out = apply_op("PolymorphicOut", T=dtypes.int32, name="p")
    self.assertEqual(dtypes.int32, int_out.dtype)
    expected = """
      name: 'p' op: 'PolymorphicOut'
      attr { key: 'T' value { type: DT_INT32 } }
      """
    self.assertProtoEquals(expected, int_out.op.node_def)

    bool_out = apply_op("PolymorphicOut", T=dtypes.bool, name="q")
    self.assertEqual(dtypes.bool, bool_out.dtype)
    expected = """
      name: 'q' op: 'PolymorphicOut'
      attr { key: 'T' value { type: DT_BOOL } }
      """
    self.assertProtoEquals(expected, bool_out.op.node_def)

    # Omitting the attr is an error.
    with self.assertRaises(TypeError) as raised:
      apply_op("PolymorphicOut")
    message = str(raised.exception)
    self.assertEqual(message,
                     "No argument for attr T")

    # Passing None for an attr that has no default is an error too.
    with self.assertRaises(TypeError) as raised:
      apply_op("PolymorphicOut", T=None)
    message = str(raised.exception)
    self.assertEqual(message,
                     "Expected DataType for argument 'T' not None.")

  def testPolymorphicDefaultOut(self):
    """Passing T=None falls back to the attr's default; other values win."""
    self._add_op("name: 'PolymorphicDefaultOut' "
                 "output_arg { name: 'out' type_attr: 'T' } "
                 "attr { name: 'T' type: 'type' "
                 "  default_value { type: DT_STRING } }")
    apply_op = self._lib.apply_op

    # None selects the declared default, DT_STRING.
    defaulted = apply_op("PolymorphicDefaultOut", T=None, name="p")
    self.assertEqual(dtypes.string, defaulted.dtype)
    expected = """
      name: 'p' op: 'PolymorphicDefaultOut'
      attr { key: 'T' value { type: DT_STRING } }
      """
    self.assertProtoEquals(expected, defaulted.op.node_def)

    overridden = apply_op("PolymorphicDefaultOut", T=dtypes.bool, name="q")
    self.assertEqual(dtypes.bool, overridden.dtype)
    expected = """
      name: 'q' op: 'PolymorphicDefaultOut'
      attr { key: 'T' value { type: DT_BOOL } }
      """
    self.assertProtoEquals(expected, overridden.op.node_def)

  def testBinary(self):
    """Both inputs of 'Binary' share attr 'T'; mismatched types are rejected."""
    self._add_op("name: 'Binary' "
                 "input_arg { name: 'a' type_attr: 'T' } "
                 "input_arg { name: 'b' type_attr: 'T' } "
                 "output_arg { name: 'out' type_attr: 'T' } "
                 "attr { name: 'T' type: 'type' }")
    apply_op = self._lib.apply_op

    int_result = apply_op("Binary", a=8, b=9, name="b")
    self.assertEqual(dtypes.int32, int_result.dtype)
    expected = """
      name: 'b' op: 'Binary' input: 'b/a' input: 'b/b'
      attr { key: 'T' value { type: DT_INT32 } }
      """
    self.assertProtoEquals(expected, int_result.op.node_def)

    string_result = apply_op("Binary", a="left", b="right", name="c")
    self.assertEqual(dtypes.string, string_result.dtype)
    expected = """
      name: 'c' op: 'Binary' input: 'c/a' input: 'c/b'
      attr { key: 'T' value { type: DT_STRING } }
      """
    self.assertProtoEquals(expected, string_result.op.node_def)

    # Python values of different types.
    with self.assertRaises(TypeError) as raised:
      apply_op("Binary", a="left", b=12)
    message = str(raised.exception)
    self.assertEqual(message,
                     "Expected string, got 12 of type 'int' instead.")

    # Outputs of other ops with different dtypes.
    with self.assertRaises(TypeError) as raised:
      apply_op("Binary",
               a=self.Tensor(dtypes.string),
               b=self.Tensor(dtypes.int32))
    message = str(raised.exception)
    self.assertEqual(message,
                     "Input 'b' of 'Binary' Op has type int32 "
                     "that does not match type string of argument 'a'.")

  def testRestrict(self):
    """Attr 'T' of 'Restrict' only accepts the allowed string and bool types."""
    self._add_op("name: 'Restrict' "
                 "input_arg { name: 'a' type_attr: 'T' } "
                 "output_arg { name: 'out' type_attr: 'T' } "
                 "attr { name: 'T' type: 'type' allowed_values { list { "
                 "  type: DT_STRING type: DT_BOOL } } }")
    apply_op = self._lib.apply_op

    string_result = apply_op("Restrict", a="foo", name="g")
    self.assertEqual(dtypes.string, string_result.dtype)
    expected = """
      name: 'g' op: 'Restrict' input: 'g/a'
      attr { key: 'T' value { type: DT_STRING } }
      """
    self.assertProtoEquals(expected, string_result.op.node_def)

    bool_result = apply_op("Restrict", a=True, name="h")
    self.assertEqual(dtypes.bool, bool_result.dtype)
    expected = """
      name: 'h' op: 'Restrict' input: 'h/a'
      attr { key: 'T' value { type: DT_BOOL } }
      """
    self.assertProtoEquals(expected, bool_result.op.node_def)

    # An int input is outside the allowed set.
    with self.assertRaises(TypeError) as raised:
      apply_op("Restrict", a=17)
    message = str(raised.exception)
    self.assertEqual(message,
                     "DataType int32 for attr 'T' "
                     "not in list of allowed values: "
                     "string, bool")

  def testTypeList(self):
    """A list(type) attr is filled from the types of the input list."""
    self._add_op("name: 'TypeList' "
                 "input_arg { name: 'a' type_list_attr: 'T' } "
                 "attr { name: 'T' type: 'list(type)' }")
    apply_op = self._lib.apply_op

    single = apply_op("TypeList", a=["foo"], name="z")
    expected = """
      name: 'z' op: 'TypeList' input: 'z/a_0'
      attr { key: 'T' value { list { type: DT_STRING } } }
      """
    self.assertProtoEquals(expected, single.node_def)

    mixed = apply_op("TypeList", a=[True, 12], name="y")
    expected = """
      name: 'y' op: 'TypeList' input: 'y/a_0' input: 'y/a_1'
      attr { key: 'T' value { list { type: DT_BOOL type: DT_INT32 } } }
      """
    self.assertProtoEquals(expected, mixed.node_def)

    no_inputs = apply_op("TypeList", a=[], name="empty")
    expected = """
      name: 'empty' op: 'TypeList' attr { key: 'T' value { list { } } }
      """
    self.assertProtoEquals(expected, no_inputs.node_def)

    # A non-list value is rejected; only the message prefix is checked.
    with self.assertRaises(TypeError) as raised:
      apply_op("TypeList", a=17)
    message = str(raised.exception)
    self.assertStartsWith(message,
                          "Expected list for 'a' "
                          "argument to 'TypeList' Op, not ")

    # A list element that cannot be converted is reported by position.
    with self.assertRaises(TypeError) as raised:
      apply_op("TypeList", a=[self.Tensor(dtypes.int32), None])
    message = str(raised.exception)
    self.assertStartsWith(message,
                          "Tensors in list passed to 'a' of 'TypeList' Op "
                          "have types [int32, <NOT CONVERTIBLE TO TENSOR>]")

  def testTypeListTwice(self):
    """Two inputs sharing one list(type) attr must have matching type lists."""
    self._add_op("name: 'TypeListTwice' "
                 "input_arg { name: 'a' type_list_attr: 'T' } "
                 "input_arg { name: 'b' type_list_attr: 'T' } "
                 "attr { name: 'T' type: 'list(type)' }")
    apply_op = self._lib.apply_op

    matching = apply_op("TypeListTwice",
                        a=["foo", True],
                        b=["bar", False],
                        name="z")
    expected = """
      name: 'z' op: 'TypeListTwice'
      input: 'z/a_0' input: 'z/a_1' input: 'z/b_0' input: 'z/b_1'
      attr { key: 'T' value { list { type: DT_STRING type: DT_BOOL } } }
      """
    self.assertProtoEquals(expected, matching.node_def)

    both_empty = apply_op("TypeListTwice", a=[], b=[], name="empty")
    expected = """
      name: 'empty' op: 'TypeListTwice' attr { key: 'T' value { list { } } }
      """
    self.assertProtoEquals(expected, both_empty.node_def)

    # The second element of 'b' is an int where 'a' has a bool.
    with self.assertRaises(TypeError) as raised:
      apply_op("TypeListTwice", a=["foo", True], b=["bar", 6])
    message = str(raised.exception)
    self.assertEqual(message,
                     "Input 'b' of 'TypeListTwice' Op has type list of "
                     "string, int32 that does not match type list "
                     "string, bool of argument 'a'.")

  def testOutTypeList(self):
    """A list(type) output attr yields one output per listed dtype."""
    self._add_op("name: 'OutTypeList' "
                 "output_arg { name: 'out' type_list_attr: 'T' } "
                 "attr { name: 'T' type: 'list(type)' }")
    apply_op = self._lib.apply_op

    # Unpacking asserts that exactly one output is produced.
    (only_out,) = apply_op("OutTypeList", T=[dtypes.float32], name="x")
    self.assertEqual(dtypes.float32, only_out.dtype)
    expected = """
      name: 'x' op: 'OutTypeList'
      attr { key: 'T' value { list { type: DT_FLOAT } } }
      """
    self.assertProtoEquals(expected, only_out.op.node_def)

    first_out, second_out = apply_op("OutTypeList",
                                     T=[dtypes.int32, dtypes.bool],
                                     name="w")
    self.assertEqual(dtypes.int32, first_out.dtype)
    self.assertEqual(dtypes.bool, second_out.dtype)
    expected = """
      name: 'w' op: 'OutTypeList'
      attr { key: 'T' value { list { type: DT_INT32 type: DT_BOOL } } }
      """
    self.assertProtoEquals(expected, first_out.op.node_def)

    # An empty type list gives an empty list of outputs.
    no_outputs = apply_op("OutTypeList", T=[], name="empty")
    self.assertEqual([], no_outputs)

    # A single dtype instead of a list is rejected.
    with self.assertRaises(TypeError) as raised:
      apply_op("OutTypeList", T=dtypes.int32)
    message = str(raised.exception)
    self.assertEqual(message, "Expected list for attr T")

  def testTypeListRestrict(self):
    """Input types for a restricted list(type) attr must all be allowed."""
    self._add_op("name: 'TypeListRestrict' "
                 "input_arg { name: 'a' type_list_attr: 'T' } "
                 "attr { name: 'T' type: 'list(type)' allowed_values { list { "
                 "  type: DT_STRING type: DT_BOOL } } }")
    apply_op = self._lib.apply_op

    allowed = apply_op("TypeListRestrict", a=["foo", False], name="v")
    expected = """
      name: 'v' op: 'TypeListRestrict' input: 'v/a_0' input: 'v/a_1'
      attr { key: 'T' value { list { type: DT_STRING type: DT_BOOL } } }
      """
    self.assertProtoEquals(expected, allowed.node_def)

    # The int element is outside the allowed set.
    with self.assertRaises(TypeError) as raised:
      apply_op("TypeListRestrict", a=[True, 12])
    message = str(raised.exception)
    self.assertEqual(message,
                     "DataType int32 for attr 'T' "
                     "not in list of allowed values: string, bool")

  def testOutTypeListRestrict(self):
    """Output dtypes for a restricted list(type) attr must all be allowed."""
    self._add_op("name: 'OutTypeListRestrict' "
                 "output_arg { name: 'out' type_list_attr: 't' } "
                 "attr { name: 't' type: 'list(type)' allowed_values { list { "
                 "  type: DT_STRING type: DT_BOOL } } }")
    apply_op = self._lib.apply_op

    first_out, second_out = apply_op("OutTypeListRestrict",
                                     t=[dtypes.bool, dtypes.string],
                                     name="u")
    self.assertEqual(dtypes.bool, first_out.dtype)
    self.assertEqual(dtypes.string, second_out.dtype)
    expected = """
      name: 'u' op: 'OutTypeListRestrict'
      attr { key: 't' value { list { type: DT_BOOL type: DT_STRING } } }
      """
    self.assertProtoEquals(expected, first_out.op.node_def)

    # int32 is outside the allowed set.
    with self.assertRaises(TypeError) as raised:
      apply_op("OutTypeListRestrict", t=[dtypes.string, dtypes.int32])
    message = str(raised.exception)
    self.assertEqual(message,
                     "DataType int32 for attr 't' "
                     "not in list of allowed values: string, bool")

  def testAttr(self):
    """An int attr accepts ints and Dimensions and rejects everything else."""
    self._add_op("name: 'Attr' attr { name: 'a' type: 'int' }")
    apply_op = self._lib.apply_op

    from_int = apply_op("Attr", a=12, name="t")
    expected = """
      name: 't' op: 'Attr' attr { key: 'a' value { i: 12 } }
      """
    self.assertProtoEquals(expected, from_int.node_def)

    from_dimension = apply_op("Attr", a=tensor_shape.Dimension(13), name="u")
    expected = """
      name: 'u' op: 'Attr' attr { key: 'a' value { i: 13 } }
      """
    self.assertProtoEquals(expected, from_dimension.node_def)

    # Values that are not ints, paired with the exact error text expected.
    rejected_values = [
        ("bad", "Expected int for argument 'a' not 'bad'."),
        ([12], "Expected int for argument 'a' not [12]."),
        (None, "Expected int for argument 'a' not None."),
    ]
    for bad_value, expected_message in rejected_values:
      with self.assertRaises(TypeError) as raised:
        apply_op("Attr", a=bad_value)
      self.assertEqual(str(raised.exception), expected_message)

    # Leaving the attr out altogether is also an error.
    with self.assertRaises(TypeError) as raised:
      apply_op("Attr")
    self.assertEqual(str(raised.exception), "No argument for attr a")

  def testAttrFloat(self):
    """A float attr accepts floats and ints but not strings."""
    self._add_op("name: 'AttrFloat' attr { name: 'a' type: 'float' }")
    apply_op = self._lib.apply_op

    from_float = apply_op("AttrFloat", a=1.2, name="t")
    expected = """
      name: 't' op: 'AttrFloat' attr { key: 'a' value { f: 1.2 } }
      """
    self.assertProtoEquals(expected, from_float.node_def)

    # An int is stored in the float field.
    from_int = apply_op("AttrFloat", a=12, name="u")
    expected = """
      name: 'u' op: 'AttrFloat' attr { key: 'a' value { f: 12 } }
      """
    self.assertProtoEquals(expected, from_int.node_def)

    with self.assertRaises(TypeError) as raised:
      apply_op("AttrFloat", a="bad")
    message = str(raised.exception)
    self.assertEqual(message,
                     "Expected float for argument 'a' not 'bad'.")

  def testAttrBool(self):
    """A bool attr accepts True/False only; ints and lists are rejected."""
    self._add_op("name: 'AttrBool' attr { name: 'a' type: 'bool' }")
    apply_op = self._lib.apply_op

    true_op = apply_op("AttrBool", a=True, name="t")
    expected = """
      name: 't' op: 'AttrBool' attr { key: 'a' value { b: true } }
      """
    self.assertProtoEquals(expected, true_op.node_def)

    false_op = apply_op("AttrBool", a=False, name="u")
    expected = """
      name: 'u' op: 'AttrBool' attr { key: 'a' value { b: false } }
      """
    self.assertProtoEquals(expected, false_op.node_def)

    # Falsy or truthy non-bool values are not silently coerced.
    rejected_values = [
        (0, "Expected bool for argument 'a' not 0."),
        (1, "Expected bool for argument 'a' not 1."),
        ([], "Expected bool for argument 'a' not []."),
    ]
    for bad_value, expected_message in rejected_values:
      with self.assertRaises(TypeError) as raised:
        apply_op("AttrBool", a=bad_value)
      self.assertEqual(str(raised.exception), expected_message)

  def testAttrBoolList(self):
    """A list(bool) attr accepts lists of bools, including the empty list."""
    self._add_op("name: 'AttrBoolList' attr { name: 'a' type: 'list(bool)' }")
    apply_op = self._lib.apply_op

    populated = apply_op("AttrBoolList", a=[True, False, True], name="t")
    expected = """
      name: 't' op: 'AttrBoolList'
      attr { key: 'a' value { list { b: true b: false b:true } } }
      """
    self.assertProtoEquals(expected, populated.node_def)

    empty = apply_op("AttrBoolList", a=[], name="u")
    expected = """
      name: 'u' op: 'AttrBoolList' attr { key: 'a' value { list { } } }
      """
    self.assertProtoEquals(expected, empty.node_def)

    # An int element is rejected.
    with self.assertRaises(TypeError) as raised:
      apply_op("AttrBoolList", a=[0])
    message = str(raised.exception)
    self.assertEqual(message,
                     "Expected bool for argument 'a' not 0.")

  def testAttrMin(self):
    """An int attr with a minimum raises ValueError for smaller values."""
    self._add_op("name: 'AttrMin' attr { name: 'a' type: 'int' "
                 "has_minimum: true minimum: 5 }")
    apply_op = self._lib.apply_op

    above_minimum = apply_op("AttrMin", a=12, name="s")
    expected = """
      name: 's' op: 'AttrMin' attr { key: 'a' value { i: 12 } }
      """
    self.assertProtoEquals(expected, above_minimum.node_def)

    with self.assertRaises(ValueError) as raised:
      apply_op("AttrMin", a=2)
    message = str(raised.exception)
    self.assertEqual(message,
                     "Attr 'a' of 'AttrMin' Op passed 2 less than minimum 5.")

  def testAttrListMin(self):
    """For a list(int) attr the minimum constrains the list's length."""
    self._add_op("name: 'AttrListMin' attr { name: 'a' type: 'list(int)' "
                 "has_minimum: true minimum: 2 }")
    apply_op = self._lib.apply_op

    long_enough = apply_op("AttrListMin", a=[1, 2], name="r")
    expected = """
      name: 'r' op: 'AttrListMin'
      attr { key: 'a' value { list { i: 1 i: 2 } } }
      """
    self.assertProtoEquals(expected, long_enough.node_def)

    # One element is fewer than the required two.
    with self.assertRaises(ValueError) as raised:
      apply_op("AttrListMin", a=[17])
    message = str(raised.exception)
    self.assertEqual(message,
                     "Attr 'a' of 'AttrListMin' Op "
                     "passed list of length 1 less than minimum 2.")

  def testAttrEnum(self):
    """A string attr with allowed_values rejects strings outside the set."""
    self._add_op("name: 'AttrEnum' "
                 "attr { name: 'a' type: 'string' "
                 "  allowed_values { list { s: 'apples' s: 'oranges' } } }")
    apply_op = self._lib.apply_op

    allowed = apply_op("AttrEnum", a="oranges", name="e")
    expected = """
      name: 'e' op: 'AttrEnum' attr { key: 'a' value { s: 'oranges' } }
      """
    self.assertProtoEquals(expected, allowed.node_def)

    with self.assertRaises(ValueError) as raised:
      apply_op("AttrEnum", a="invalid")
    message = str(raised.exception)
    self.assertEqual(message,
                     'Attr \'a\' of \'AttrEnum\' Op '
                     'passed string \'invalid\' not in: '
                     '"apples", "oranges".')

  def testAttrEnumList(self):
    """Each element of a list(string) attr must be one of the allowed values."""
    self._add_op("name: 'AttrEnumList' "
                 "attr { name: 'a' type: 'list(string)' "
                 "  allowed_values { list { s: 'apples' s: 'oranges' } } }")
    apply_op = self._lib.apply_op

    all_allowed = apply_op("AttrEnumList", a=["oranges", "apples"], name="f")
    expected = """
      name: 'f' op: 'AttrEnumList'
      attr { key: 'a' value { list { s: 'oranges' s: 'apples' } } }
      """
    self.assertProtoEquals(expected, all_allowed.node_def)

    # One disallowed element among allowed ones is reported by value.
    with self.assertRaises(ValueError) as raised:
      apply_op("AttrEnumList", a=["apples", "invalid", "oranges"])
    message = str(raised.exception)
    self.assertEqual(message,
                     'Attr \'a\' of \'AttrEnumList\' Op '
                     'passed string \'invalid\' not '
                     'in: "apples", "oranges".')

  def testAttrShape(self):
    """A shape attr accepts lists, tuples, TensorShapes and shape protos."""
    self._add_op("name: 'AttrShape' attr { name: 'a' type: 'shape' }")
    apply_op = self._lib.apply_op

    from_list = apply_op("AttrShape", a=[5], name="s1")
    expected = """
      name: 's1' op: 'AttrShape'
      attr { key: 'a' value { shape { dim { size: 5 } } } }
      """
    self.assertProtoEquals(expected, from_list.node_def)

    from_tuple = apply_op("AttrShape", a=(4, 3, 2), name="s2")
    expected = """
      name: 's2' op: 'AttrShape'
      attr { key: 'a' value {
        shape { dim { size: 4 } dim { size: 3 } dim { size: 2 } } } }
      """
    self.assertProtoEquals(expected, from_tuple.node_def)

    from_tensor_shape = apply_op(
        "AttrShape", a=tensor_shape.TensorShape([3, 2]), name="s3")
    expected = """
      name: 's3' op: 'AttrShape'
      attr { key: 'a' value {
        shape { dim { size: 3 } dim { size: 2 } } } }
      """
    self.assertProtoEquals(expected, from_tensor_shape.node_def)

    from_empty_list = apply_op("AttrShape", a=[], name="s4")
    expected = """
      name: 's4' op: 'AttrShape' attr { key: 'a' value { shape { } } }
      """
    self.assertProtoEquals(expected, from_empty_list.node_def)

    # A TensorShapeProto built by hand, one dimension at a time.
    shape_proto = tensor_shape_pb2.TensorShapeProto()
    for dim_size in (6, 3):
      shape_proto.dim.add().size = dim_size
    from_proto = apply_op("AttrShape", a=shape_proto, name="s5")
    expected = """
      name: 's5' op: 'AttrShape'
      attr { key: 'a' value { shape { dim { size: 6 } dim { size: 3 } } } }
      """
    self.assertProtoEquals(expected, from_proto.node_def)

    # TODO(josh11b): Re-enable this test once we stop promoting scalars to
    # shapes.
    # with self.assertRaises(TypeError) as cm:
    #   self._lib.apply_op("AttrShape", a=5)
    # self.assertEqual(str(cm.exception),
    #                  "Don't know how to convert 5 to a TensorShapeProto for "
    #                  "argument 'a'")

    # A string is rejected; the message is not checked.
    with self.assertRaises(TypeError):
      apply_op("AttrShape", a="ABC")

  def testAttrShapeList(self):
    """A list(shape) attr accepts a list of shapes, including the empty list."""
    self._add_op("name: 'AttrShapeList' attr { name: 'a' type: 'list(shape)' }")
    apply_op = self._lib.apply_op

    two_shapes = apply_op("AttrShapeList", a=[[3, 2], [6, 5, 4]], name="sl")
    expected = """
      name: 'sl' op: 'AttrShapeList'
      attr { key: 'a' value { list {
        shape { dim { size: 3 } dim { size: 2 } }
        shape { dim { size: 6 } dim { size: 5 } dim { size: 4 } } } } }
      """
    self.assertProtoEquals(expected, two_shapes.node_def)

    no_shapes = apply_op("AttrShapeList", a=[], name="esl")
    expected = """
      name: 'esl' op: 'AttrShapeList' attr { key: 'a' value { list { } } }
      """
    self.assertProtoEquals(expected, no_shapes.node_def)

  def testAttrPartialShape(self):
    """Unknown dimensions in a shape attr are expected to be encoded as -1."""
    self._add_op(
        "name: 'AttrPartialShape' attr { name: 'a' type: 'shape' }")
    apply_op = self._lib.apply_op

    fully_known = apply_op("AttrPartialShape", a=[5], name="s1")
    expected = """
      name: 's1' op: 'AttrPartialShape'
      attr { key: 'a' value { shape { dim { size: 5 } } } }
      """
    self.assertProtoEquals(expected, fully_known.node_def)

    # A None entry in a tuple stands for an unknown dimension.
    from_tuple = apply_op("AttrPartialShape", a=(4, None, 2), name="s2")
    expected = """
      name: 's2' op: 'AttrPartialShape'
      attr { key: 'a' value {
        shape { dim { size: 4 } dim { size: -1 } dim { size: 2 } } } }
      """
    self.assertProtoEquals(expected, from_tuple.node_def)

    from_tensor_shape = apply_op(
        "AttrPartialShape", a=tensor_shape.TensorShape([3, None]), name="s3")
    expected = """
      name: 's3' op: 'AttrPartialShape'
      attr { key: 'a' value {
        shape { dim { size: 3 } dim { size: -1 } } } }
      """
    self.assertProtoEquals(expected, from_tensor_shape.node_def)

    from_empty_list = apply_op("AttrPartialShape", a=[], name="s4")
    expected = """
      name: 's4' op: 'AttrPartialShape'
      attr { key: 'a' value { shape { } } }
      """
    self.assertProtoEquals(expected, from_empty_list.node_def)

    # A TensorShapeProto built by hand with a -1 leading dimension.
    shape_proto = tensor_shape_pb2.TensorShapeProto()
    for dim_size in (-1, 3):
      shape_proto.dim.add().size = dim_size
    from_proto = apply_op("AttrPartialShape", a=shape_proto, name="s5")
    expected = """
      name: 's5' op: 'AttrPartialShape'
      attr { key: 'a' value {
        shape { dim { size: -1 } dim { size: 3 } } } }
      """
    self.assertProtoEquals(expected, from_proto.node_def)

    # TODO(ebrevdo): Re-enable once we stop promoting scalars to shapes.
    # with self.assertRaises(TypeError) as cm:
    #   self._lib.apply_op("AttrPartialShape", a=5)
    # self.assertEqual(str(cm.exception),
    #                  "Don't know how to convert 5 to a TensorShapeProto for "
    #                  "argument 'a'")

    # A string is rejected; the message is not checked.
    with self.assertRaises(TypeError):
      apply_op("AttrPartialShape", a="ABC")

  def testAttrPartialShapeList(self):
    """A list(shape) attr accepts shapes that contain unknown dimensions."""
    self._add_op("""
      name: 'AttrPartialShapeList'
      attr { name: 'a' type: 'list(shape)' }
    """)
    apply_op = self._lib.apply_op

    # The None in the second shape is expected to be encoded as size -1.
    two_shapes = apply_op(
        "AttrPartialShapeList", a=[[3, 2], [6, None, 4]], name="sl")
    expected = """
      name: 'sl' op: 'AttrPartialShapeList'
      attr { key: 'a' value { list {
        shape { dim { size: 3 } dim { size: 2 } }
        shape { dim { size: 6 } dim { size: -1 } dim { size: 4 } } } } }
      """
    self.assertProtoEquals(expected, two_shapes.node_def)

    no_shapes = apply_op("AttrPartialShapeList", a=[], name="esl")
    expected = """
      name: 'esl' op: 'AttrPartialShapeList' attr {
        key: 'a' value { list { } } }
      """
    self.assertProtoEquals(expected, no_shapes.node_def)

  def testAttrDefault(self):
    """Passing None for a string attr with a default selects that default."""
    self._add_op("name: 'AttrDefault' "
                 "attr { name: 'a' type: 'string' "
                 "  default_value { s: 'banana' } }")
    apply_op = self._lib.apply_op

    defaulted = apply_op("AttrDefault", a=None, name="d")
    expected = """
      name: 'd' op: 'AttrDefault' attr { key: 'a' value { s: 'banana' } }
      """
    self.assertProtoEquals(expected, defaulted.node_def)

    overridden = apply_op("AttrDefault", a="kiwi", name="c")
    expected = """
      name: 'c' op: 'AttrDefault' attr { key: 'a' value { s: 'kiwi' } }
      """
    self.assertProtoEquals(expected, overridden.node_def)

  def testAttrListDefault(self):
    """A list(int) attr uses its default for None but keeps [3] and []."""
    op_def_text = ("name: 'AttrListDefault' "
                   "attr { name: 'a' type: 'list(int)' "
                   "  default_value { list { i: 5 i: 15 } } }")
    self._add_op(op_def_text)

    # None selects the declared default list.
    defaulted = self._lib.apply_op("AttrListDefault", a=None, name="b")
    self.assertProtoEquals("""
      name: 'b' op: 'AttrListDefault'
      attr { key: 'a' value { list { i: 5 i: 15 } } }
      """, defaulted.node_def)

    # A non-empty list replaces the default.
    explicit = self._lib.apply_op("AttrListDefault", a=[3], name="a")
    self.assertProtoEquals("""
      name: 'a' op: 'AttrListDefault'
      attr { key: 'a' value { list { i: 3 } } }
      """, explicit.node_def)

    # An explicit empty list stays empty; it is not treated like None.
    emptied = self._lib.apply_op("AttrListDefault", a=[], name="empty")
    self.assertProtoEquals("""
      name: 'empty' op: 'AttrListDefault'
      attr { key: 'a' value { list { } } }
      """, emptied.node_def)

  def testAttrEmptyListDefault(self):
    """A list(float) attr whose default is the empty list."""
    op_def_text = ("name: 'AttrEmptyListDefault' "
                   "attr { name: 'a' type: 'list(float)' "
                   "       default_value { list { } } }")
    self._add_op(op_def_text)

    # None selects the (empty) default list.
    defaulted = self._lib.apply_op("AttrEmptyListDefault", a=None, name="b")
    self.assertProtoEquals("""
      name: 'b' op: 'AttrEmptyListDefault'
      attr { key: 'a' value { list { } } }
      """, defaulted.node_def)

    # The integer 3 is expected to be stored as a float entry.
    explicit = self._lib.apply_op("AttrEmptyListDefault", a=[3], name="a")
    self.assertProtoEquals("""
      name: 'a' op: 'AttrEmptyListDefault'
      attr { key: 'a' value { list { f: 3 } } }
      """, explicit.node_def)

    # An explicit empty list gives the same attr value as the default.
    emptied = self._lib.apply_op("AttrEmptyListDefault", a=[], name="empty")
    self.assertProtoEquals("""
      name: 'empty' op: 'AttrEmptyListDefault'
      attr { key: 'a' value { list { } } }
      """, emptied.node_def)

  def testReservedAttr(self):
    """An attr named 'range' is supplied through the keyword 'range_'."""
    op_def_text = ("name: 'ReservedAttr' "
                   "attr { name: 'range' type: 'int' } ")
    self._add_op(op_def_text)

    # The node_def is expected to carry the attr under its original key.
    reserved = self._lib.apply_op("ReservedAttr", range_=7, name="x")
    expected = """
      name: 'x' op: 'ReservedAttr' attr { key: 'range' value { i: 7 } }
      """
    self.assertProtoEquals(expected, reserved.node_def)

  def testNIntsIn(self):
    """Exercises an int32 input list whose length is bound to attr 'N'."""
    self._add_op("name: 'NIntsIn' "
                 "input_arg { name: 'a' type: DT_INT32 number_attr: 'N' } "
                 "attr { name: 'N' type: 'int' has_minimum: true minimum: 2 }")

    # Valid lists: 'N' is expected to equal the list length.
    pair_op = self._lib.apply_op("NIntsIn", a=[1, 2], name="n")
    self.assertProtoEquals("""
      name: 'n' op: 'NIntsIn' input: 'n/a_0' input: 'n/a_1'
      attr { key: 'N' value { i: 2 } }
      """, pair_op.node_def)

    five_op = self._lib.apply_op("NIntsIn", a=[5, 4, 3, 2, 1], name="o")
    self.assertProtoEquals("""
      name: 'o' op: 'NIntsIn'
      input: 'o/a_0' input: 'o/a_1' input: 'o/a_2' input: 'o/a_3' input: 'o/a_4'
      attr { key: 'N' value { i: 5 } }
      """, five_op.node_def)

    # All-string lists (Python values or tensors) share one error message.
    all_strings_msg = ("Tensors in list passed to 'a' of 'NIntsIn' Op have "
                       "types [string, string] that do not match expected "
                       "type int32.")

    with self.assertRaises(TypeError) as raised:
      self._lib.apply_op("NIntsIn", a=["foo", "bar"])
    self.assertEqual(str(raised.exception), all_strings_msg)

    with self.assertRaises(TypeError) as raised:
      self._lib.apply_op("NIntsIn",
                         a=[self.Tensor(dtypes.string),
                            self.Tensor(dtypes.string)])
    self.assertEqual(str(raised.exception), all_strings_msg)

    # A list shorter than the declared minimum of 2 is rejected.
    with self.assertRaises(ValueError) as raised:
      self._lib.apply_op("NIntsIn", a=[99])
    self.assertEqual(
        str(raised.exception),
        "List argument 'a' to 'NIntsIn' Op with length 1 shorter than "
        "minimum length 2.")

    # Mixed int/string lists (Python values or tensors) share one message.
    mixed_msg = ("Tensors in list passed to 'a' of 'NIntsIn' Op have "
                 "types [int32, string] that do not match expected "
                 "type int32.")

    with self.assertRaises(TypeError) as raised:
      self._lib.apply_op("NIntsIn", a=[38, "bar"])
    self.assertEqual(str(raised.exception), mixed_msg)

    with self.assertRaises(TypeError) as raised:
      self._lib.apply_op("NIntsIn",
                         a=[self.Tensor(dtypes.int32),
                            self.Tensor(dtypes.string)])
    self.assertEqual(str(raised.exception), mixed_msg)

    # A bare scalar is not accepted where a list is required.
    with self.assertRaises(TypeError) as raised:
      self._lib.apply_op("NIntsIn", a=17)
    self.assertStartsWith(
        str(raised.exception),
        "Expected list for 'a' argument to 'NIntsIn' Op, not ")

  def testNPolymorphicIn(self):
    """Exercises a same-typed input list whose dtype attr 'T' is inferred."""
    self._add_op("name: 'NPolymorphicIn' "
                 "input_arg { name: 'a' type_attr: 'T' number_attr: 'N' } "
                 "attr { name: 'T' type: 'type' } "
                 "attr { name: 'N' type: 'int' has_minimum: true minimum: 2 }")

    # Python ints: 'T' is expected to be DT_INT32.
    ints_op = self._lib.apply_op("NPolymorphicIn", a=[1, 2], name="n")
    self.assertProtoEquals("""
      name: 'n' op: 'NPolymorphicIn' input: 'n/a_0' input: 'n/a_1'
      attr { key: 'T' value { type: DT_INT32 } }
      attr { key: 'N' value { i: 2 } }
      """, ints_op.node_def)

    five_op = self._lib.apply_op("NPolymorphicIn", a=[5, 4, 3, 2, 1], name="o")
    self.assertProtoEquals("""
      name: 'o' op: 'NPolymorphicIn'
      input: 'o/a_0' input: 'o/a_1' input: 'o/a_2' input: 'o/a_3' input: 'o/a_4'
      attr { key: 'T' value { type: DT_INT32 } }
      attr { key: 'N' value { i: 5 } }
      """, five_op.node_def)

    # Python strings: 'T' is expected to be DT_STRING.
    strings_op = self._lib.apply_op("NPolymorphicIn", a=["foo", "bar"],
                                    name="p")
    self.assertProtoEquals("""
      name: 'p' op: 'NPolymorphicIn' input: 'p/a_0' input: 'p/a_1'
      attr { key: 'T' value { type: DT_STRING } }
      attr { key: 'N' value { i: 2 } }
      """, strings_op.node_def)

    # A Python scalar next to a float32 tensor: 'T' is expected to be DT_FLOAT.
    scalar_and_tensor = [1, self.Tensor(dtypes.float32, name="x")]
    mixed_op = self._lib.apply_op("NPolymorphicIn", a=scalar_and_tensor,
                                  name="q")
    self.assertProtoEquals("""
      name: 'q' op: 'NPolymorphicIn' input: 'q/a_0' input: 'x'
      attr { key: 'T' value { type: DT_FLOAT } }
      attr { key: 'N' value { i: 2 } }
      """, mixed_op.node_def)

    # float32 and float32_ref tensors are accepted together as DT_FLOAT.
    value_and_ref = [self.Tensor(dtypes.float32, name="y"),
                     self.Tensor(dtypes.float32_ref, name="z")]
    ref_op = self._lib.apply_op("NPolymorphicIn", a=value_and_ref, name="r")
    self.assertProtoEquals("""
      name: 'r' op: 'NPolymorphicIn' input: 'y' input: 'z'
      attr { key: 'T' value { type: DT_FLOAT } }
      attr { key: 'N' value { i: 2 } }
      """, ref_op.node_def)

    # A list shorter than the declared minimum of 2 is rejected.
    with self.assertRaises(ValueError) as raised:
      self._lib.apply_op("NPolymorphicIn", a=[99])
    self.assertEqual(str(raised.exception),
                     "List argument 'a' to 'NPolymorphicIn' Op with length 1 "
                     "shorter than minimum length 2.")

    # int/string mixes (Python values or tensors) share one error message.
    int_string_msg = ("Tensors in list passed to 'a' of 'NPolymorphicIn' Op "
                      "have types [int32, string] that don't all match.")

    with self.assertRaises(TypeError) as raised:
      self._lib.apply_op("NPolymorphicIn", a=[38, "bar"])
    self.assertEqual(str(raised.exception), int_string_msg)

    with self.assertRaises(TypeError) as raised:
      self._lib.apply_op("NPolymorphicIn", a=[38, self.Tensor(dtypes.string)])
    self.assertEqual(str(raised.exception), int_string_msg)

    # None cannot be converted to a tensor and is reported as such.
    with self.assertRaises(TypeError) as raised:
      self._lib.apply_op("NPolymorphicIn", a=[38, None])
    self.assertEqual(str(raised.exception),
                     "Tensors in list passed to 'a' of 'NPolymorphicIn' Op "
                     "have types [int32, <NOT CONVERTIBLE TO TENSOR>] that "
                     "don't all match.")

    with self.assertRaises(TypeError) as raised:
      self._lib.apply_op("NPolymorphicIn",
                         a=["abcd", self.Tensor(dtypes.int32)])
    self.assertEqual(str(raised.exception),
                     "Tensors in list passed to 'a' of 'NPolymorphicIn' Op "
                     "have types [string, int32] that don't all match.")

    # A bare scalar is not accepted where a list is required.
    with self.assertRaises(TypeError) as raised:
      self._lib.apply_op("NPolymorphicIn", a=17)
    self.assertStartsWith(
        str(raised.exception),
        "Expected list for 'a' argument to 'NPolymorphicIn' Op, not ")

  def testNPolymorphicRestrictIn(self):
    """An inferred list dtype must be one of the attr's allowed_values."""
    self._add_op("name: 'NPolymorphicRestrictIn' "
                 "input_arg { name: 'a' type_attr: 'T' number_attr: 'N' } "
                 "attr { name: 'T' type: 'type' allowed_values { "
                 "  list { type: DT_STRING type: DT_BOOL } } } "
                 "attr { name: 'N' type: 'int' has_minimum: true minimum: 2 }")

    # Strings are allowed.
    strings_op = self._lib.apply_op(
        "NPolymorphicRestrictIn", a=["foo", "bar"], name="p")
    self.assertProtoEquals("""
      name: 'p' op: 'NPolymorphicRestrictIn' input: 'p/a_0' input: 'p/a_1'
      attr { key: 'T' value { type: DT_STRING } }
      attr { key: 'N' value { i: 2 } }
      """, strings_op.node_def)

    # Bools are allowed.
    bools_op = self._lib.apply_op(
        "NPolymorphicRestrictIn", a=[False, True, False], name="b")
    self.assertProtoEquals("""
      name: 'b' op: 'NPolymorphicRestrictIn'
      input: 'b/a_0' input: 'b/a_1' input: 'b/a_2'
      attr { key: 'T' value { type: DT_BOOL } }
      attr { key: 'N' value { i: 3 } }
      """, bools_op.node_def)

    # int32 is outside the allowed set and must be rejected.
    with self.assertRaises(TypeError) as raised:
      self._lib.apply_op("NPolymorphicRestrictIn", a=[1, 2])
    expected_msg = ("DataType int32 for attr 'T' "
                    "not in list of allowed values: string, bool")
    self.assertEqual(str(raised.exception), expected_msg)

  def testNInTwice(self):
    """Two fixed-type input lists share the same length attr 'N'."""
    self._add_op("name: 'NInTwice' "
                 "input_arg { name: 'a' type: DT_INT32 number_attr: 'N' } "
                 "input_arg { name: 'b' type: DT_STRING number_attr: 'N' } "
                 "attr { name: 'N' type: 'int' has_minimum: true minimum: 0 }")

    paired = self._lib.apply_op(
        "NInTwice", a=[1, 2], b=["one", "two"], name="n")
    self.assertProtoEquals("""
      name: 'n' op: 'NInTwice'
      input: 'n/a_0' input: 'n/a_1' input: 'n/b_0' input: 'n/b_1'
      attr { key: 'N' value { i: 2 } }
      """, paired.node_def)

    # The minimum is 0, so two empty lists are valid and give N == 0.
    no_inputs = self._lib.apply_op("NInTwice", a=[], b=[], name="o")
    self.assertProtoEquals("""
      name: 'o' op: 'NInTwice' attr { key: 'N' value { i: 0 } }
      """, no_inputs.node_def)

    # Lists of different lengths cannot both satisfy the same 'N'.
    with self.assertRaises(ValueError) as raised:
      self._lib.apply_op("NInTwice", a=[1, 2, 3], b=["too short"])
    self.assertEqual(
        str(raised.exception),
        "List argument 'b' to 'NInTwice' Op with length 1 must match "
        "length 3 of argument 'a'.")

  def testNInPolymorphicTwice(self):
    """Two input lists must agree on both the length 'N' and the dtype 'T'."""
    self._add_op("name: 'NInPolymorphicTwice' "
                 "input_arg { name: 'a' type_attr: 'T' number_attr: 'N' } "
                 "input_arg { name: 'b' type_attr: 'T' number_attr: 'N' } "
                 "attr { name: 'T' type: 'type' } "
                 "attr { name: 'N' type: 'int' has_minimum: true minimum: 0 }")

    matched = self._lib.apply_op(
        "NInPolymorphicTwice", a=[1, 2], b=[3, 4], name="n")
    self.assertProtoEquals("""
      name: 'n' op: 'NInPolymorphicTwice'
      input: 'n/a_0' input: 'n/a_1' input: 'n/b_0' input: 'n/b_1'
      attr { key: 'T' value { type: DT_INT32 } }
      attr { key: 'N' value { i: 2 } }
      """, matched.node_def)

    # Length mismatch between 'a' and 'b'.
    with self.assertRaises(ValueError) as raised:
      self._lib.apply_op("NInPolymorphicTwice", a=[1, 2, 3], b=[5])
    self.assertEqual(
        str(raised.exception),
        "List argument 'b' to 'NInPolymorphicTwice' Op with length 1 "
        "must match length 3 of argument 'a'.")

    # dtype mismatch using Python values: 'a' fixes T, 'b' disagrees.
    with self.assertRaises(TypeError) as raised:
      self._lib.apply_op("NInPolymorphicTwice", a=[1, 2], b=["one", "two"])
    self.assertEqual(
        str(raised.exception),
        "Tensors in list passed to 'b' of 'NInPolymorphicTwice' "
        "Op have types [string, string] that do not match type "
        "int32 inferred from earlier arguments.")

    # dtype mismatch using tensors.
    with self.assertRaises(TypeError) as raised:
      self._lib.apply_op("NInPolymorphicTwice",
                         a=[self.Tensor(dtypes.int32)],
                         b=[self.Tensor(dtypes.string)])
    self.assertEqual(
        str(raised.exception),
        "Tensors in list passed to 'b' of "
        "'NInPolymorphicTwice' Op have types [string] that do not "
        "match type int32 inferred from earlier arguments.")

  def testNInTwoTypeVariables(self):
    """Two input lists share the length 'N' but have separate dtype attrs."""
    self._add_op("name: 'NInTwoTypeVariables' "
                 "input_arg { name: 'a' type_attr: 'S' number_attr: 'N' } "
                 "input_arg { name: 'b' type_attr: 'T' number_attr: 'N' } "
                 "attr { name: 'S' type: 'type' } "
                 "attr { name: 'T' type: 'type' } "
                 "attr { name: 'N' type: 'int' has_minimum: true minimum: 0 }")

    # Different dtypes for the two lists.
    int_bool = self._lib.apply_op(
        "NInTwoTypeVariables", a=[1, 2], b=[True, False], name="n")
    self.assertProtoEquals("""
      name: 'n' op: 'NInTwoTypeVariables'
      input: 'n/a_0' input: 'n/a_1' input: 'n/b_0' input: 'n/b_1'
      attr { key: 'S' value { type: DT_INT32 } }
      attr { key: 'T' value { type: DT_BOOL } }
      attr { key: 'N' value { i: 2 } }
      """, int_bool.node_def)

    # The same dtype for both lists is fine as well.
    int_int = self._lib.apply_op(
        "NInTwoTypeVariables", a=[1, 2], b=[3, 4], name="o")
    self.assertProtoEquals("""
      name: 'o' op: 'NInTwoTypeVariables'
      input: 'o/a_0' input: 'o/a_1' input: 'o/b_0' input: 'o/b_1'
      attr { key: 'S' value { type: DT_INT32 } }
      attr { key: 'T' value { type: DT_INT32 } }
      attr { key: 'N' value { i: 2 } }
      """, int_int.node_def)

    # Tensor inputs: the two dtypes are taken from the tensors.
    tensors_op = self._lib.apply_op(
        "NInTwoTypeVariables",
        a=[self.Tensor(dtypes.int32, name="q")],
        b=[self.Tensor(dtypes.string, name="r")],
        name="p")
    self.assertProtoEquals("""
      name: 'p' op: 'NInTwoTypeVariables' input: 'q' input: 'r'
      attr { key: 'S' value { type: DT_INT32 } }
      attr { key: 'T' value { type: DT_STRING } }
      attr { key: 'N' value { i: 1 } }
      """, tensors_op.node_def)

    # The shared length is still enforced.
    with self.assertRaises(ValueError) as raised:
      self._lib.apply_op("NInTwoTypeVariables", a=[1, 2, 3], b=["5"])
    self.assertEqual(
        str(raised.exception),
        "List argument 'b' to 'NInTwoTypeVariables' Op with length 1 "
        "must match length 3 of argument 'a'.")

  def testInPolymorphicTwice(self):
    """Two input lists share the dtype 'T' but have separate lengths."""
    self._add_op("name: 'InPolymorphicTwice' "
                 "input_arg { name: 'a' type_attr: 'T' number_attr: 'N' } "
                 "input_arg { name: 'b' type_attr: 'T' number_attr: 'M' } "
                 "attr { name: 'T' type: 'type' } "
                 "attr { name: 'N' type: 'int' has_minimum: true minimum: 0 } "
                 "attr { name: 'M' type: 'int' has_minimum: true minimum: 0 } ")

    # Different lengths are fine because N and M are independent.
    uneven = self._lib.apply_op(
        "InPolymorphicTwice", a=[8], b=[3, 4, 5], name="n")
    self.assertProtoEquals("""
      name: 'n' op: 'InPolymorphicTwice'
      input: 'n/a_0' input: 'n/b_0' input: 'n/b_1' input: 'n/b_2'
      attr { key: 'T' value { type: DT_INT32 } }
      attr { key: 'N' value { i: 1 } }
      attr { key: 'M' value { i: 3 } }
      """, uneven.node_def)

    # An empty second list is accepted once 'a' has fixed the dtype.
    empty_b = self._lib.apply_op("InPolymorphicTwice", a=[8], b=[], name="o")
    self.assertProtoEquals("""
      name: 'o' op: 'InPolymorphicTwice' input: 'o/a_0'
      attr { key: 'T' value { type: DT_INT32 } }
      attr { key: 'N' value { i: 1 } }
      attr { key: 'M' value { i: 0 } }
      """, empty_b.node_def)

    # An empty first list leaves nothing to infer 'T' from.
    with self.assertRaises(TypeError) as raised:
      self._lib.apply_op("InPolymorphicTwice", a=[], b=[3, 4, 5])
    self.assertEqual(
        str(raised.exception),
        "Don't know how to infer type variable from empty input "
        "list passed to input 'a' of 'InPolymorphicTwice' Op.")

    # dtype mismatch using Python values.
    with self.assertRaises(TypeError) as raised:
      self._lib.apply_op("InPolymorphicTwice", a=[1, 2], b=["one", "two"])
    self.assertEqual(
        str(raised.exception),
        "Tensors in list passed to 'b' of 'InPolymorphicTwice' Op "
        "have types [string, string] that do not match type int32 "
        "inferred from earlier arguments.")

    # dtype mismatch using tensors.
    with self.assertRaises(TypeError) as raised:
      self._lib.apply_op("InPolymorphicTwice",
                         a=[self.Tensor(dtypes.int32)],
                         b=[self.Tensor(dtypes.string)])
    self.assertEqual(
        str(raised.exception),
        "Tensors in list passed to 'b' of 'InPolymorphicTwice' "
        "Op have types [string] that do not match type int32 "
        "inferred from earlier arguments.")

  def testNIntsOut(self):
    """An op with 'N' int32 outputs returns exactly N int32 tensors."""
    self._add_op("name: 'NIntsOut' "
                 "output_arg { name: 'a' type: DT_INT32 number_attr: 'N' } "
                 "attr { name: 'N' type: 'int' has_minimum: true minimum: 2 }")

    # Unpacking into two names also checks that exactly two outputs came back.
    first, second = self._lib.apply_op("NIntsOut", N=2, name="n")
    for tensor in (first, second):
      self.assertEqual(dtypes.int32, tensor.dtype)
    self.assertProtoEquals("""
      name: 'n' op: 'NIntsOut' attr { key: 'N' value { i: 2 } }
      """, first.op.node_def)

    t1, t2, t3, t4, t5 = self._lib.apply_op("NIntsOut", N=5, name="o")
    for tensor in (t1, t2, t3, t4, t5):
      self.assertEqual(dtypes.int32, tensor.dtype)
    self.assertProtoEquals("""
      name: 'o' op: 'NIntsOut' attr { key: 'N' value { i: 5 } }
      """, t5.op.node_def)

    # N below the declared minimum of 2.
    with self.assertRaises(ValueError) as raised:
      self._lib.apply_op("NIntsOut", N=1)
    self.assertEqual(
        str(raised.exception),
        "Attr 'N' of 'NIntsOut' Op passed 1 less than minimum 2.")

    # N must be an int, not a list.
    with self.assertRaises(TypeError) as raised:
      self._lib.apply_op("NIntsOut", N=[3])
    self.assertEqual(
        str(raised.exception),
        "Expected int for argument 'N' not [3].")

  def testNIntsOutDefault(self):
    """The output count follows N's default of 3 unless N is given."""
    self._add_op("name: 'NIntsOutDefault' "
                 "output_arg { name: 'a' type: DT_INT32 number_attr: 'N' } "
                 "attr { name: 'N' type: 'int' has_minimum: true minimum: 2"
                 "  default_value { i:3 } }")

    # N=None selects the default, so three outputs are expected.
    d1, d2, d3 = self._lib.apply_op("NIntsOutDefault", N=None, name="z")
    for tensor in (d1, d2, d3):
      self.assertEqual(dtypes.int32, tensor.dtype)
    self.assertProtoEquals("""
      name: 'z' op: 'NIntsOutDefault' attr { key: 'N' value { i: 3 } }
      """, d1.op.node_def)

    # An explicit N overrides the default.
    e1, e2 = self._lib.apply_op("NIntsOutDefault", N=2, name="y")
    for tensor in (e1, e2):
      self.assertEqual(dtypes.int32, tensor.dtype)
    self.assertProtoEquals("""
      name: 'y' op: 'NIntsOutDefault' attr { key: 'N' value { i: 2 } }
      """, e2.op.node_def)

  def testNPolymorphicOut(self):
    """An op with 'N' outputs of dtype 'T' honors both attrs."""
    self._add_op("name: 'NPolymorphicOut' "
                 "output_arg { name: 'a' type_attr: 'T' number_attr: 'N' } "
                 "attr { name: 'T' type: 'type' } "
                 "attr { name: 'N' type: 'int' has_minimum: true minimum: 2 }")

    i1, i2 = self._lib.apply_op(
        "NPolymorphicOut", N=2, T=dtypes.int32, name="n")
    for tensor in (i1, i2):
      self.assertEqual(dtypes.int32, tensor.dtype)
    self.assertProtoEquals("""
      name: 'n' op: 'NPolymorphicOut'
      attr { key: 'T' value { type: DT_INT32 } }
      attr { key: 'N' value { i: 2 } }
      """, i1.op.node_def)

    s1, s2, s3 = self._lib.apply_op(
        "NPolymorphicOut", T=dtypes.string, N=3, name="o")
    for tensor in (s1, s2, s3):
      self.assertEqual(dtypes.string, tensor.dtype)
    self.assertProtoEquals("""
      name: 'o' op: 'NPolymorphicOut'
      attr { key: 'T' value { type: DT_STRING } }
      attr { key: 'N' value { i: 3 } }
      """, s3.op.node_def)

    # N below the declared minimum of 2.
    with self.assertRaises(ValueError) as raised:
      self._lib.apply_op("NPolymorphicOut", N=1, T=dtypes.string)
    self.assertEqual(
        str(raised.exception),
        "Attr 'N' of 'NPolymorphicOut' Op "
        "passed 1 less than minimum 2.")

    # T must be a single DataType, not a list of them.
    with self.assertRaises(TypeError) as raised:
      self._lib.apply_op("NPolymorphicOut", N=3, T=[dtypes.string])
    self.assertEqual(
        str(raised.exception),
        "Expected DataType for argument 'T' not [tf.string].")

  def testNPolymorphicOutDefault(self):
    """Covers all four default/explicit combinations of attrs N and T."""
    self._add_op("name: 'NPolymorphicOutDefault' "
                 "output_arg { name: 'a' type_attr: 'T' number_attr: 'N' } "
                 "attr { name: 'T' type: 'type'"
                 "  default_value { type: DT_BOOL } } "
                 "attr { name: 'N' type: 'int' has_minimum: true minimum: 2 "
                 "  default_value { i: 2 } }")

    # Both defaulted: two bool outputs.
    r1, r2 = self._lib.apply_op(
        "NPolymorphicOutDefault", N=None, T=None, name="r")
    for tensor in (r1, r2):
      self.assertEqual(dtypes.bool, tensor.dtype)
    self.assertProtoEquals("""
      name: 'r' op: 'NPolymorphicOutDefault'
      attr { key: 'T' value { type: DT_BOOL } }
      attr { key: 'N' value { i: 2 } }
      """, r1.op.node_def)

    # Explicit N, default T: three bool outputs.
    s1, s2, s3 = self._lib.apply_op(
        "NPolymorphicOutDefault", N=3, T=None, name="s")
    for tensor in (s1, s2, s3):
      self.assertEqual(dtypes.bool, tensor.dtype)
    self.assertProtoEquals("""
      name: 's' op: 'NPolymorphicOutDefault'
      attr { key: 'T' value { type: DT_BOOL } }
      attr { key: 'N' value { i: 3 } }
      """, s1.op.node_def)

    # Default N, explicit T: two int32 outputs.
    t1, t2 = self._lib.apply_op(
        "NPolymorphicOutDefault", N=None, T=dtypes.int32, name="t")
    for tensor in (t1, t2):
      self.assertEqual(dtypes.int32, tensor.dtype)
    self.assertProtoEquals("""
      name: 't' op: 'NPolymorphicOutDefault'
      attr { key: 'T' value { type: DT_INT32 } }
      attr { key: 'N' value { i: 2 } }
      """, t1.op.node_def)

    # Both explicit: three int32 outputs.
    u1, u2, u3 = self._lib.apply_op(
        "NPolymorphicOutDefault", N=3, T=dtypes.int32, name="u")
    for tensor in (u1, u2, u3):
      self.assertEqual(dtypes.int32, tensor.dtype)
    self.assertProtoEquals("""
      name: 'u' op: 'NPolymorphicOutDefault'
      attr { key: 'T' value { type: DT_INT32 } }
      attr { key: 'N' value { i: 3 } }
      """, u1.op.node_def)

  def testNPolymorphicRestrictOut(self):
    """An explicit output dtype must be one of the attr's allowed_values."""
    self._add_op("name: 'NPolymorphicRestrictOut' "
                 "output_arg { name: 'a' type_attr: 'T' number_attr: 'N' } "
                 "attr { name: 'T' type: 'type' allowed_values { "
                 "  list { type: DT_STRING type: DT_BOOL } } } "
                 "attr { name: 'N' type: 'int' has_minimum: true minimum: 2 }")

    # bool is in the allowed set.
    b1, b2, b3 = self._lib.apply_op(
        "NPolymorphicRestrictOut", N=3, T=dtypes.bool, name="u")
    for tensor in (b1, b2, b3):
      self.assertEqual(dtypes.bool, tensor.dtype)
    self.assertProtoEquals("""
      name: 'u' op: 'NPolymorphicRestrictOut'
      attr { key: 'T' value { type: DT_BOOL } }
      attr { key: 'N' value { i: 3 } }
      """, b1.op.node_def)

    # int32 is outside the allowed set and must be rejected.
    with self.assertRaises(TypeError) as raised:
      self._lib.apply_op("NPolymorphicRestrictOut", N=2, T=dtypes.int32)
    expected_msg = ("DataType int32 for attr 'T' "
                    "not in list of allowed values: string, bool")
    self.assertEqual(str(raised.exception), expected_msg)

  def testRef(self):
    """Covers is_ref inputs and outputs and the '_class' attr they add."""
    self._add_op("name: 'RefIn' "
                 "input_arg { name: 'a' type_attr: 'T' is_ref: true } "
                 "attr { name: 'T' type: 'type' } ")
    self._add_op("name: 'TwoRefsIn' "
                 "input_arg { name: 'a' type_attr: 'T' is_ref: true } "
                 "input_arg { name: 'b' type_attr: 'T' is_ref: true } "
                 "attr { name: 'T' type: 'type' } ")
    self._add_op("name: 'RefOut' "
                 "output_arg { name: 'a' type_attr: 'T' is_ref: true } "
                 "attr { name: 'T' type: 'type' } ")

    # A ref output reports the *_ref dtype while the attr stays the base type.
    bool_ref = self._lib.apply_op("RefOut", T=dtypes.bool, name="o")
    self.assertEqual(dtypes.bool_ref, bool_ref.dtype)
    self.assertProtoEquals("""
      name: 'o' op: 'RefOut'
      attr { key: 'T' value { type: DT_BOOL } }
      """, bool_ref.op.node_def)

    # Feeding a ref into a ref input is expected to add a "loc:@o" class.
    ref_consumer = self._lib.apply_op("RefIn", a=bool_ref, name="i")
    self.assertProtoEquals("""
      name: 'i' op: 'RefIn' input: 'o'
      attr { key: 'T' value { type: DT_BOOL } }
      attr { key: "_class" value { list { s: "loc:@o" } } }
      """, ref_consumer.node_def)

    # Can pass ref to non-ref input.
    int_ref = self._lib.apply_op("RefOut", T=dtypes.int32, name="r")
    simple_out = self._lib.apply_op("Simple", a=int_ref, name="s")
    self.assertProtoEquals("""
      name: 's' op: 'Simple' input: 'r'
      """, simple_out.op.node_def)

    # Can't pass non-ref to ref input.
    with self.assertRaises(TypeError) as raised:
      self._lib.apply_op("RefIn", a=2)
    self.assertEqual(
        str(raised.exception),
        "Input 'a' of 'RefIn' Op requires l-value input")

    ref_t = self._lib.apply_op("RefOut", T=dtypes.int32, name="t")
    ref_u = self._lib.apply_op("RefOut", T=dtypes.int32, name="u")
    two_refs = self._lib.apply_op("TwoRefsIn", a=ref_t, b=ref_u, name="v")
    # NOTE(mrry): The order of colocation constraints is an implementation
    # detail.
    self.assertProtoEquals("""
      name: 'v' op: 'TwoRefsIn' input: 't' input: 'u'
      attr { key: 'T' value { type: DT_INT32 } }
      attr { key: "_class" value { list { s: "loc:@t" s: "loc:@u" } } }
      """, two_refs.node_def)

  def testSpecifyDevice(self):
    """Every node created inside a device scope carries that device."""
    device = "/job:ADevice"
    with self._g.device(device):
      self._lib.apply_op("Simple", a=3)
    # We look at the whole graph here to make sure the Const op is also given
    # the specified device.
    nodes = self._g.as_graph_def().node
    self.assertEqual(len(nodes), 2)
    for graph_node in nodes:
      self.assertDeviceEqual(graph_node.device, device)

  def testStructuredOutputSingleList(self):
    """A lone number_attr output comes back as a list of length n_a."""
    self._add_op("name: 'SimpleStruct' "
                 "output_arg { name: 'a' type: DT_INT32 number_attr: 'n_a' } "
                 "attr { name: 'n_a' type: 'int' }")
    # 0 and 1 are included so that empty and single-element results are also
    # checked to be lists.
    for n_a in [0, 1, 3]:
      a = self._lib.apply_op("SimpleStruct", n_a=n_a)
      # assertIsInstance reports the actual type on failure, which
      # assertTrue(isinstance(...)) does not.
      self.assertIsInstance(a, list)
      self.assertEqual(n_a, len(a))

  def testStructuredOutputListAndSingle(self):
    """A list output followed by a single output is returned as (list, Tensor)."""
    self._add_op("name: 'MixedStruct' "
                 "output_arg { name: 'a' type: DT_INT32 number_attr: 'n_a' } "
                 "output_arg { name: 'b' type: DT_FLOAT } "
                 "attr { name: 'n_a' type: 'int' }")
    for n_a in [0, 1, 3]:
      a, b = self._lib.apply_op("MixedStruct", n_a=n_a)
      # The assertIsInstance / per-element assertEqual forms name the
      # offending type or dtype on failure instead of "False is not true".
      self.assertIsInstance(a, list)
      self.assertEqual(n_a, len(a))
      for x in a:
        self.assertEqual(dtypes.int32, x.dtype)
      self.assertIsInstance(b, ops.Tensor)
      self.assertEqual(dtypes.float32, b.dtype)

  def testStructuredOutputMultipleLists(self):
    """Three list-valued outputs are returned as three separate sequences."""
    self._add_op("name: 'ComplexStruct' "
                 "output_arg { name: 'a' type: DT_INT32 number_attr: 'n_a' } "
                 "output_arg { name: 'b' type: DT_INT64 number_attr: 'n_b' } "
                 "output_arg { name: 'c' type_list_attr: 't_c' } "
                 "attr { name: 'n_a' type: 'int' } "
                 "attr { name: 'n_b' type: 'int' } "
                 "attr { name: 'n_c' type: 'int' } "
                 "attr { name: 't_c' type: 'list(type)' }".replace(
                     "attr { name: 'n_c' type: 'int' } ", ""))
    # Every combination of list lengths / type lists is exercised.
    for n_a in [0, 1, 3]:
      for n_b in [0, 1, 3]:
        for t_c in [[],
                    [dtypes.int32],
                    [dtypes.int32, dtypes.float32]]:
          a, b, c = self._lib.apply_op("ComplexStruct",
                                       n_a=n_a,
                                       n_b=n_b,
                                       t_c=t_c)

          # Per-element assertEqual names the unexpected dtype on failure,
          # unlike assertTrue(all(...)).
          self.assertEqual(n_a, len(a))
          for x in a:
            self.assertEqual(dtypes.int32, x.dtype)
          self.assertEqual(n_b, len(b))
          for x in b:
            self.assertEqual(dtypes.int64, x.dtype)
          self.assertEqual(t_c, [x.dtype for x in c])


class OpDefLibraryGraphTest(test_util.TensorFlowTestCase):
  """Tests which graph the ops created by OpDefLibrary.apply_op end up in."""

  def setUp(self):
    """Creates a fresh library and graph and registers the ops under test."""
    self._lib = OpDefLibrary()
    self._g = ops.Graph()
    self._add_op("name: 'Simple' input_arg { name: 'a' type: DT_INT32 } "
                 "output_arg { name: 'out' type: DT_FLOAT }")
    self._add_op("name: 'Binary' "
                 "input_arg { name: 'a' type_attr: 'T' } "
                 "input_arg { name: 'b' type_attr: 'T' } "
                 "output_arg { name: 'out' type_attr: 'T' } "
                 "attr { name: 'T' type: 'type' }")

  def _add_op(self, ascii):
    """Parses a text-format OpDef and adds it to the library under test.

    Args:
      ascii: Text-format `OpDef` protocol buffer describing the op.
    """
    op_def = op_def_pb2.OpDef()
    text_format.Merge(ascii, op_def)
    self._lib.add_op(op_def)

  def testNoGraph(self):
    """With no graph made current, the op lands in the default graph."""
    out = self._lib.apply_op("Simple", a=3)
    self.assertEqual(out.graph, ops.get_default_graph())

  def testDefaultGraph(self):
    """Inside `as_default()`, the op lands in that graph."""
    with self._g.as_default():
      out = self._lib.apply_op("Simple", a=3)
      self.assertEqual(out.graph, self._g)

  def testDifferentGraphFails(self):
    """Combining inputs that come from two different graphs raises ValueError."""
    with self._g.as_default():
      a = self._lib.apply_op("Simple", a=3)
    other_g = ops.Graph()
    with other_g.as_default():
      b = self._lib.apply_op("Simple", a=4)
    with self.assertRaises(ValueError) as cm:
      self._lib.apply_op("Binary", a=a, b=b)
    # assertIn shows the actual message on failure, unlike assertTrue(x in y).
    self.assertIn("must be from the same graph", str(cm.exception))


# Hand off to googletest.main() when this file is executed as a script.
if __name__ == "__main__":
  googletest.main()
